sup-extra HDU-1402 A*B Problem Plus FFT模板题
#HDU-1402 A*B Problem Plus FFT模板题
题目就是一个大数乘法,由于位数长度达到了50000级别所以不能用常规写法,需要使用FFT解题。
那么首先需要知道的问题是,FFT是什么,是拿来干什么的?
FFT,即为快速傅氏变换,是离散傅氏变换的快速算法,它是根据离散傅氏变换的奇、偶、虚、实等特性,对离散傅立叶变换的算法进行改进获得的。
看上去很复杂,但实际上我们使用FFT是要解决的就是多项式乘法。多项式的表示有两种:
1.系数表示法,形如f(x)=a0+a1X+a2X2+…anXn,其中告诉你所有的a的值,就是系数表示法
2.点值表示法,利用(x0,y0) (x1,y1)…(Xn,yn)这么n个点来约束的n次多项式,你知道所有的x,y的值,就是点值表示法
如果F(x)=f(x)*g(x) 那么F(x0)=f(x0) * g(x0) 这个结论看上去就很自然。所以对于n+1个点我们都可以通过算出来的f(x) g(x)求出相应的F(x)。也就是说我们只要知道f(x) g(x)的点值表示法,算出n+1个点,就能得到乘法乘出来的F(x) 函数的点值表示法。
所以为了求出点值表示法的函数,我们就需要先找n+1个x的值出来。为了便于计算xk 我们选取的就是复数平面上单位元的点,这样k次方得到的就是1了。FFT实现的就是把系数表示法变成点值表示法,把点值表示法变成系数表示法。
至于FFT可以变成2个DFT啊、求和的分治算法为了防止爆栈而变成下标反向的循环运算啊、IDFT和DFT之间只差一个1/n这种细节就不用理解了,还有诸如“你把一个函数扩展到了复数平面上后结果要是有虚部怎么办的问题”啊,我只能说理论上两个纯实部的函数就算代进虚部的函数值生成的新函数也不会产生有虚部的变量,至于为什么只能说是一些奇妙的性质了,总之明白了FFT是把什么东西变成了什么之后你甚至不用知道原理,拿着板子用就行。
附上代码:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <map>
#include <cmath>
#define maxl (1<<16)
#define pi 3.141592653589793238462643383
using namespace std;
struct complex
{
double re,im;
complex (double r=0.0,double i=0.0){re=r;im=i;}
}a[maxl*2],b[maxl*2],w[2][maxl*2];
complex operator +(const complex&x,const complex&y)
{
return complex(x.re+y.re,x.im+y.im);
}
complex operator -(const complex&x,const complex&y)
{
return complex(x.re-y.re,x.im-y.im);
}
complex operator *(const complex&x,const complex&y)
{
return complex(x.re*y.re-x.im*y.im,x.im*y.re+x.re*y.im);
}
int n,na,nb,rev[maxl*2],res[maxl*2];
char oria[maxl],orib[maxl];
void init()//calculate reverse
{
for(int i=0;i<n;++i){rev[i]=(rev[i>>1]>>1)|(i&1)<<(len-1);}
for(int i=0;i<n;i++)
{
w[0][i]=w[1][i]=complex(cos(2*pi*i/n),sin(2*pi*i/n));
w[1][i].im=-w[0][i].im;
}
}
void FFT(complex *a,int order)
{
complex x,y;
for(int i=0;i<n;i++)
{
if(i<rev[i])swap(a[i],a[rev[i]]);
}
for(int i=1;i<n;i<<=1)
{
for(int j=0,t=n/(i<<1);j<n;j+=i<<1)
for(int k=0,l=0;k<i;k++,l+=t)
{
x=w[order][l]*a[j+k+i];
y=a[j+k];
a[j+k]=y+x;
a[j+k+i]=y-x;
}
}
if(order)for(int i=0;i<n;i++) a[i].re/=n;
}
int main(void)
{
char ch;
while(~scanf("%s%s",oria,orib))
{
na=nb=0;
while(oria[na]!='\0')
{
a[na].re=oria[na]-'0';
a[na].im=0;
na++;
}
while(orib[nb]!='\0')
{
b[nb].re=orib[nb]-'0';
b[nb].im=0;
nb++;
}
n=1;
while(n<na||n<nb) n<<=1;
n<<=1;
//cout<<"n:"<<n<<endl;
init();
//cout<<"done"<<endl;
FFT(a,0);
FFT(b,0);
for(int i=0;i<n;i++) a[i]=a[i]*b[i];
FFT(a,1);
memset(res,0,sizeof(res));
for(int i=0;i<n;i++) {
res[i]=int(a[i].re+0.5);
}
for(int i=n-1;i>0;i--)
{
res[i-1]+=res[i]/10;
res[i]%=10;
//cout<<res[i];
}
int resn=na+nb-1;
for(int i=0;i<resn;i++)
{
if(res[i]!=0)
{
for(;i<resn;i++)
printf("%d",res[i]);
break;
}
else if(i==resn-1) printf("0");
}
printf("\n");
for(int i=0;i<n;i++)
{
a[i].re=a[i].im=b[i].re=b[i].im=0;
}
/*for(int i=0;i<n;i++)
printf("%lf%lf\n",a[i].re,a[i].im);*/
}
return 0;
}